import logging
import os
from typing import Callable, Dict, Optional, Tuple
import hydra
import torch  # noqa
import torch.nn as nn
import numpy as np
import functools
import nltk
import pickle

from omegaconf.dictconfig import DictConfig
from transformers import (
    AutoConfig,
    AutoTokenizer,
    LlamaTokenizer,
    is_torch_tpu_available,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

from . import gist_llama, gist_t5
from .arguments import Arguments, global_setup
from .data import alpaca
from .data.utils import nested_select
from .gist_llama import DEBUG_LLAMA_CONFIG, GistLlamaForCausalLM
from .gist_t5 import GistT5ForConditionalGeneration
from .integrations import CustomWandbCallback, EvaluateFirstStepCallback
from .metrics import get_compute_metrics_fn
from .trainer_seq2seq import GistSeq2SeqTrainer
from .get_data import get_dataset

from transformers import PreTrainedTokenizer, Seq2SeqTrainer, GenerationConfig
from transformers.trainer_utils import EvalPrediction
import evaluate

# Will error if the minimal version of Transformers is not installed. Remove at
# your own risk.
check_min_version("4.28.0.dev0")

# Checked at import time: raise early if the installed `datasets` package is
# older than required, with an install hint.
require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
)

logger = logging.getLogger(__name__)
# Mapping of metric name -> value, as returned by compute_metrics below.
Metrics = Dict[str, float]

def postprocess_text(preds, labels, remove_llama_padding=False):
    """Normalize decoded predictions and references before scoring them.

    Args:
        preds: Decoded prediction strings.
        labels: Decoded reference strings.
        remove_llama_padding: When True, delete every "⁇" character from both
            sides before further cleanup.

    Returns:
        A ``(preds, labels)`` tuple of new lists. Each string has surrounding
        whitespace stripped and is re-joined with one sentence per line (as
        split by ``nltk.sent_tokenize``), the layout rougeLSum expects.
    """

    def _normalize(text):
        if remove_llama_padding:
            # XXX: temporary hack because skip_special_tokens doesn't seem to
            # work with the Llama SentencePiece tokenizer.
            text = text.replace("⁇", "")
        sentences = nltk.sent_tokenize(text.strip())
        # rougeLSum expects a newline after each sentence.
        return "\n".join(sentences)

    cleaned_preds = [_normalize(p) for p in preds]
    cleaned_labels = [_normalize(ref) for ref in labels]
    return cleaned_preds, cleaned_labels

def compute_metrics(
    eval_preds: EvalPrediction,
    tokenizer: PreTrainedTokenizer,
    eval_metric: str = "rouge",
    output_file: Optional[str] = None,
    remove_llama_padding: bool = True,
) -> Metrics:
    """Compute evaluation metrics from a Trainer ``EvalPrediction``.

    Label positions equal to -100 are treated as ignored.

    Args:
        eval_preds: Predictions and label ids produced by the trainer. If
            ``predictions`` is a tuple, only its first element is used.
        tokenizer: Tokenizer used to decode token ids back into text.
        eval_metric: ``"reg"`` computes a mean-squared error over the
            non-ignored positions (NaN if every position is ignored);
            ``"gen"`` prints the raw and decoded predictions and returns a
            placeholder score of 1; any other value is loaded with
            ``evaluate.load`` (e.g. ``"rouge"``) and its scores are scaled
            by 100 and rounded to 4 decimals.
        output_file: Currently unused; kept for signature compatibility.
        remove_llama_padding: Forwarded to ``postprocess_text`` to strip the
            "⁇" characters left in decoded text.

    Returns:
        A dict mapping metric names to values. For text metrics it also
        contains ``gen_len``, the mean count of non-pad prediction tokens.
    """
    results = {}

    predictions = eval_preds.predictions
    # The tuple check must happen before np.array(); afterwards the value is
    # always an ndarray and the check could never succeed.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    preds = np.array(predictions)
    labels = np.array(eval_preds.label_ids)
    ignored = labels == -100

    if eval_metric == "reg":
        # nn.MSELoss only accepts tensors, so compute the MSE with numpy and
        # only over positions whose label is not ignored.
        valid = ~ignored
        if np.any(valid):
            diff = preds[valid].astype(np.float64) - labels[valid].astype(np.float64)
            results["score"] = float(np.mean(diff**2))
        else:
            results["score"] = float("nan")
    else:
        raw_preds = preds.copy() if eval_metric == "gen" else None

        # -100 cannot be decoded: map ignored label positions and any -100
        # prediction ids to the pad token first.
        preds = np.where(ignored | (preds == -100), tokenizer.pad_token_id, preds)
        labels = np.where(ignored, tokenizer.pad_token_id, labels)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        logger.debug("dec preds %s", decoded_preds)
        logger.debug("dec labels %s", decoded_labels)

        decoded_preds, decoded_labels = postprocess_text(
            decoded_preds,
            decoded_labels,
            remove_llama_padding=remove_llama_padding,
        )

        if eval_metric == "gen":
            # Generation-only mode: show the outputs; the score is a placeholder.
            print(raw_preds, decoded_preds)
            results["score"] = 1
        else:
            # Other candidates: "bleu", "chrf", "bertscore" (lang="en").
            metric_results = evaluate.load(eval_metric).compute(
                predictions=decoded_preds, references=decoded_labels, use_stemmer=True
            )
            results.update({k: round(v * 100, 4) for k, v in metric_results.items()})

            prediction_lens = [
                np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
            ]
            results["gen_len"] = np.mean(prediction_lens)

    logger.info("results %s", results)
    return results

@hydra.main(config_path="conf", config_name="config")
def main(args: DictConfig) -> None:
    """Evaluate a gist-token model restored from ``args.training.output_dir``.

    Builds the config, tokenizer and model, and extends the augmented embedder
    with previously saved token embeddings (if ``args.model.token_dict_path``
    is set). It then loads ``augmented_embedder.pth`` (and ``lm_head_reg.pth``
    when ``args.model.regression`` is set) from the output directory and runs
    ``Seq2SeqTrainer.evaluate`` on the evaluation dataset.

    With ``manual_eval`` enabled (the default here), the embedder's ``in_reg``
    weights and the model's ``pred_val`` / ``true_val`` / ``pnlogit``
    attributes are saved as ``.npy`` files in the output directory after the
    first task, and the function returns.
    """
    args: Arguments = global_setup(args)
    output_dir = args.training.output_dir

    # Detecting last checkpoint.
    last_checkpoint = None
    if (
        os.path.isdir(output_dir)
        and args.training.do_train
        and not args.training.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(output_dir)
        if last_checkpoint is None and len(os.listdir(output_dir)) > 0:
            existing_files = os.listdir(output_dir)
            logger.warning(
                (
                    "Output directory (%s) already exists and "
                    "is not empty. Existing files: %s. "
                    "Training anyways as these may just be output files."
                ),
                output_dir,
                str(existing_files),
            )
        elif (
            last_checkpoint is not None and args.training.resume_from_checkpoint is None
        ):
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To "
                "avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from "
                "scratch."
            )

    # Set seed before initializing model.
    set_seed(args.training.seed)

    config_kwargs = {
        "cache_dir": args.model.cache_dir,
        "revision": args.model.model_revision,
        "use_auth_token": True if args.model.use_auth_token else None,
    }

    if args.model.llama_debug:
        if args.model.pretrained:
            raise RuntimeError("llama_debug requires pretrained set to False")
        config = DEBUG_LLAMA_CONFIG
    elif args.model.config_name:
        config = AutoConfig.from_pretrained(args.model.config_name, **config_kwargs)
    elif args.model.model_name_or_path:
        config = AutoConfig.from_pretrained(
            args.model.model_name_or_path, **config_kwargs
        )
    else:
        raise ValueError(
            "Unlike run_clm.py, this script does not support specifying a model type "
            "from scratch. Specify args.model.model_name_or_path and set "
            "args.pretrained = False to train from scratch instead."
        )

    # model_name_or_path may be unset (e.g. with config_name); avoid calling
    # .lower() on None so the "not supported" error below is raised instead.
    model_name = (args.model.model_name_or_path or "").lower()
    is_t5 = any(t in model_name for t in ("t5", "tk"))
    is_llama = "llama" in model_name

    #############################
    task_name = args.model.task_name
    # Read before get_dataset(), which receives it; it was previously assigned
    # only after this call, raising UnboundLocalError.
    freeze = args.model.freeze_existing_tokens

    if args.model.token_dict_path is not None:
        token_dir = args.model.token_dict_path
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(os.path.join(token_dir, "token_name_dict.pkl"), "rb") as f:
            token_name_dict = pickle.load(f)

        embedding_weights = torch.from_numpy(
            np.load(os.path.join(token_dir, "embedder_weights.npy"))
        ).float()
        num_existing_tokens = embedding_weights.shape[0]
    else:
        token_name_dict = {}
        embedding_weights = None
        num_existing_tokens = 0

    (
        train_dataset,
        eval_dataset,
        token_name_dict,
        num_new_tokens,
        update_tokens,
        start_markers,
    ) = get_dataset(
        task_name,
        num_existing_tokens,
        token_name_dict,
        args.model.num_token_per_prompt,
        args.model.use_end_marker,
        args.model.use_scalar_encode,
        args.model.inverse_prompting,
        freeze,
    )
    config.update(
        {
            "num_new_tokens": num_new_tokens,
            "output_dir": output_dir,
            "regression_out_dim": args.model.regression_out_dim,
        }
    )

    #############################

    tokenizer_kwargs = {
        "cache_dir": args.model.cache_dir,
        "use_fast": args.model.use_fast_tokenizer,
        "revision": args.model.model_revision,
        "use_auth_token": True if args.model.use_auth_token else None,
    }

    if args.model.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model.tokenizer_name, **tokenizer_kwargs
        )
    elif args.model.model_name_or_path:
        if is_llama:
            tokenizer = LlamaTokenizer.from_pretrained(
                args.model.model_name_or_path, **tokenizer_kwargs
            )
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "left"
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                args.model.model_name_or_path, **tokenizer_kwargs
            )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported "
            "by this script."
            "You can do it from another script, save it, and load it from here, using "
            "--tokenizer_name."
        )

    if is_t5:
        model_cls = GistT5ForConditionalGeneration
    elif is_llama:
        model_cls = GistLlamaForCausalLM
    else:
        raise ValueError(f"Model type {args.model.model_name_or_path} not supported")
    if args.model.pretrained:
        model = model_cls.from_pretrained(
            args.model.model_name_or_path,
            from_tf=bool(".ckpt" in args.model.model_name_or_path),
            config=config,
            cache_dir=args.model.cache_dir,
            revision=args.model.model_revision,
            use_auth_token=True if args.model.use_auth_token else None,
        )
    else:
        model = model_cls(config)

    embedder = model.model.augmented_embedder
    if embedding_weights is not None:
        if freeze:
            # Append the saved token embeddings after the original vocabulary.
            embedder.original_embedder = nn.Embedding.from_pretrained(
                torch.cat([embedder.original_embedder.weight.data, embedding_weights])
            )
        else:
            # Prepend the saved token embeddings to the new-token table.
            embedder.embedding = nn.Embedding.from_pretrained(
                torch.cat([embedding_weights, embedder.embedding.weight.data])
            )
    embedder.load_state_dict(
        torch.load(os.path.join(output_dir, "augmented_embedder.pth"))
    )

    if freeze:
        embedder.vocab_size = embedder.original_embedder.weight.data.shape[0]
        embedder.added_tokens = [embedder.vocab_size + i for i in range(num_new_tokens)]
    else:
        embedder.added_tokens = [
            embedder.vocab_size + i
            for i in range(num_existing_tokens + num_new_tokens)
        ]

    if args.model.regression:
        model.lm_head_reg = nn.Linear(
            config.hidden_size, args.model.regression_out_dim, bias=False
        )
        model.lm_head_reg.load_state_dict(
            torch.load(os.path.join(output_dir, "lm_head_reg.pth"))
        )
        # NOTE(review): this attribute name looks like it may be meant to be
        # `prediction_loss_only`; kept unchanged because that would skip
        # compute_metrics — confirm intent.
        args.training.set_prediction_loss_only = True

    # ==== BEGIN GIST CHANGES ====
    # Check if gist token has already been added to the model (e.g. because
    # we're resuming from a checkpoint.)

    if is_t5 and len(tokenizer) == gist_t5.PRETRAINED_VOCAB_SIZE + num_new_tokens:
        assert model.shared.weight.shape[0] == gist_t5.PRETRAINED_VOCAB_SIZE + num_new_tokens
    elif is_llama and len(tokenizer) == gist_llama.PRETRAINED_VOCAB_SIZE + num_new_tokens:
        assert (
            model.model.embed_tokens.weight.shape[0]
            == gist_llama.PRETRAINED_VOCAB_SIZE + num_new_tokens
        )
        assert model.lm_head.weight.shape[0] == gist_llama.PRETRAINED_VOCAB_SIZE + num_new_tokens
    else:
        # Initialize gist token
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    "<GIST " + str(i) + ">"
                    for i in range(num_existing_tokens + num_new_tokens)
                ]
            }
        )

    special_tokens = tokenizer.additional_special_tokens_ids
    print("Speacial token ids:", special_tokens)
    print("Token name dict:", token_name_dict)

    model = model.cuda().eval()

    # Debug hook: set to a prompt string to run a single greedy generation,
    # print it, and stop.
    manual_prompt = None
    if manual_prompt is not None:
        for token, actual_rep in token_name_dict.items():
            manual_prompt = manual_prompt.replace(token, actual_rep)

        input_ids = tokenizer.encode(manual_prompt)
        seq_len = len(input_ids)

        input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).cuda()
        attention_mask_with_gist = torch.tensor([1] * seq_len).unsqueeze(0).cuda()

        prompt_len = len(tokenizer(manual_prompt)["input_ids"]) - 1
        # Positions whose id is >= 32000 (past the base Llama vocabulary size,
        # presumably the added tokens — confirm); ids above 32009 are skipped.
        pidx = np.argwhere(np.array(input_ids) >= 32000).flatten()
        attn_mask = torch.zeros(1, seq_len, seq_len)
        for j, x in enumerate(pidx):
            if input_ids[x] > 32009:
                continue
            if j % 10 == 0:
                for y in range(x, min(x + 20 + prompt_len, seq_len)):
                    attn_mask[0, y, :x] = 1
        attention_mask_with_gist2 = attn_mask.unsqueeze(1).cuda()

        gen_kwargs = {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_with_gist,
            "attention_mask_gist": attention_mask_with_gist2,
            "do_sample": False,
            "manual_eval": True,
        }
        generated_tokens = model.generate(max_new_tokens=32, **gen_kwargs)

        output = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        print("Output:", output)
        return

    if is_t5:
        data_collator = alpaca.collator.DataCollatorForAlpaca(
            tokenizer,
            model=model,
            padding="longest",
            # Chosen so that <1% of examples are truncated.
            # See data/alpaca_plus/length_stats.txt for length stats.
            max_source_length=128,
            max_target_length=256,
            # Human eval examples are longer.
            max_source_length_human=384,
            max_target_length_human=384,
            label_pad_token_id=-100,
            pad_to_multiple_of=8 if args.training.fp16 else None,
            gist_condition=args.training.gist.condition,
            num_gist_tokens=args.training.gist.num_gist_tokens,
            gist_token=special_tokens,
            pad_token=tokenizer.pad_token_id,
            add_gist_token=args.training.gist.add_gist_token,
        )
    elif is_llama:
        # This data collator variant does causal language modeling with left
        # padding.
        data_collator = alpaca.collator.DataCollatorForAlpacaCLM(
            tokenizer,
            icl_dataset=train_dataset if args.model.icl_method is not None else None,
            method=args.model.icl_method,
            num_demonstrations=args.model.icl_num_demonstrations,
            idx_dict=np.load(args.model.icl_idx_dict_path) if args.model.icl_idx_dict_path is not None else None,
            max_length=256 + 256 + 128 + 2048 + 2048,
            pad_token=tokenizer.pad_token_id,
            check_correctness=True,
            token_name_dict=token_name_dict,
            update_tokens=update_tokens,
            eval_mode=True,
            start_markers=start_markers,
            num_token_per_prompt=args.model.num_token_per_prompt,
            use_scalar_encode=args.model.use_scalar_encode,
            use_end_marker=args.model.use_end_marker,
            add_ce_loss=args.model.add_ce_loss,
        )
    else:
        assert False, "should be is_llama or is_t5"

    all_tasks = ["BA"]  # alternatively: list(set(list(eval_dataset["task"])))

    args.training.num_train_epochs = 1

    manual_eval = True
    model.manual_eval = manual_eval

    for eval_task in all_tasks:
        # Per-task subsetting is disabled: every task sees the full eval set.
        eval_dataset_subset = eval_dataset

        if eval_task == "gen" or eval_task[:3] == "reg":
            eval_metric = eval_task[:3]
        else:
            eval_metric = "rouge"

        compute_metrics_task = functools.partial(
            compute_metrics, tokenizer=tokenizer, eval_metric=eval_metric
        )

        trainer = Seq2SeqTrainer(
            model=model,
            args=args.training,
            train_dataset=None,
            eval_dataset=eval_dataset_subset,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics_task,
            preprocess_logits_for_metrics=None,
        )

        trainer.evaluate()

        if manual_eval:
            setting = "_kidneydebug"
            np.save(
                os.path.join(output_dir, f"in_reg{setting}.npy"),
                model.model.augmented_embedder.in_reg.weight.data.detach().cpu().numpy(),
            )
            np.save(os.path.join(output_dir, f"manual_eval_preds{setting}.npy"), model.pred_val)
            np.save(os.path.join(output_dir, f"manual_eval_labels{setting}.npy"), model.true_val)
            np.save(os.path.join(output_dir, f"manual_eval_logits{setting}.npy"), model.pnlogit)
            # Only the first task is evaluated when manual_eval is enabled.
            return


if __name__ == "__main__":
    # `main` is wrapped by @hydra.main, so it is called with no arguments;
    # Hydra composes the DictConfig from the "config" config in the "conf"
    # directory plus any command-line overrides.
    main()
